package com.gitee.cashzhang27.test.cloud.oauth.auth.server.extended.util;


import static org.springframework.security.jwt.codec.Codecs.b64UrlDecode;
import static org.springframework.security.jwt.codec.Codecs.b64UrlEncode;
import static org.springframework.security.jwt.codec.Codecs.concat;
import static org.springframework.security.jwt.codec.Codecs.utf8Decode;
import static org.springframework.security.jwt.codec.Codecs.utf8Encode;

import cn.hutool.crypto.SecureUtil;
import cn.hutool.crypto.asymmetric.KeyType;
import cn.hutool.crypto.asymmetric.RSA;
import java.nio.CharBuffer;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;
import org.apache.commons.lang.StringUtils;
import org.springframework.security.jwt.BinaryFormat;
import org.springframework.security.jwt.Jwt;
import org.springframework.security.jwt.crypto.sign.SignatureVerifier;
import org.springframework.security.jwt.crypto.sign.Signer;
import org.springframework.util.Base64Utils;

/**
 * Static helpers for decoding, verifying and encoding JWTs with this package's {@link JwtHeader}
 * and {@link JwtImpl} types, plus a pair of two-layer RSA string encryption helpers.
 *
 * @author Cash Zhang
 * @version v1.0
 * @since 2019/02/10 15:40
 */
public class JwtHelper {

  /** UTF-8 bytes of the "." that separates token segments. Shared; must not be modified. */
  static final byte[] PERIOD = utf8Encode(".");

  /**
   * Creates a token from an encoded token string. The signature is NOT checked here; use
   * {@link #decodeAndVerify(String, SignatureVerifier)} for that.
   *
   * @param token the (non-null) encoded token (three Base-64 encoded strings separated by "."
   *     characters)
   * @return the decoded token
   * @throws IllegalArgumentException if the token does not contain two "." separators with a
   *     non-empty header segment, if the header parser rejects the header, or if the crypto
   *     segment is empty while the header's {@code alg} is not "none"
   */
  public static Jwt decode(String token) {
    int firstPeriod = token.indexOf('.');
    int lastPeriod = token.lastIndexOf('.');

    if (firstPeriod <= 0 || lastPeriod <= firstPeriod) {
      throw new IllegalArgumentException("JWT must have 3 tokens");
    }
    CharBuffer buffer = CharBuffer.wrap(token, 0, firstPeriod);
    // TODO: Use a Reader which supports CharBuffer
    JwtHeader header = JwtHeaderHelper.create(buffer.toString());

    // Re-window the same buffer over the claims segment (between the two periods).
    buffer.limit(lastPeriod).position(firstPeriod + 1);
    byte[] claims = b64UrlDecode(buffer);
    boolean emptyCrypto = lastPeriod == token.length() - 1;

    byte[] crypto;

    if (emptyCrypto) {
      // Only tokens whose header says alg "none" may omit the crypto segment.
      if (!"none".equals(header.parameters.alg)) {
        throw new IllegalArgumentException(
            "Signed or encrypted token must have non-empty crypto segment");
      }
      crypto = new byte[0];
    } else {
      buffer.limit(token.length()).position(lastPeriod + 1);
      crypto = b64UrlDecode(buffer);
    }
    return new JwtImpl(header, claims, crypto);
  }

  /**
   * Decodes a token and verifies its signature with {@code verifier}.
   *
   * @param token the encoded token
   * @param verifier the verifier applied to the signing input and crypto segment
   * @return the decoded token, only if verification did not throw
   */
  public static Jwt decodeAndVerify(String token, SignatureVerifier verifier) {
    Jwt jwt = decode(token);
    jwt.verifySignature(verifier);

    return jwt;
  }

  /**
   * Returns the header parameters of an encoded token without verifying it.
   *
   * @param token the encoded token
   * @return a new map holding the non-standard parameters followed by "alg" and "typ"
   */
  public static Map<String, String> headers(String token) {
    JwtImpl jwt = (JwtImpl) decode(token);
    Map<String, String> map = new LinkedHashMap<>(jwt.header.parameters.map);
    map.put("alg", jwt.header.parameters.alg);
    if (jwt.header.parameters.typ != null) {
      map.put("typ", jwt.header.parameters.typ);
    }
    return map;
  }

  /**
   * Reverses {@link #encryptStr(CharSequence, Map)}. It Base64-decodes {@code content}, decrypts
   * the outer layer with the "ftf" public key, then decrypts the inner layer with the "app"
   * private key.
   *
   * @param content Base64-encoded ciphertext
   * @param rsaKeyPair key material; {@code rsaFtfPubKey} and {@code rsaAppPriKey} are required,
   *     and {@code rsaFtfPriKey} / {@code rsaAppPubKey} are passed through if present
   * @return the UTF-8 decoded plaintext
   * @throws IllegalArgumentException if a required key is missing or blank
   */
  public static String decryptStr(CharSequence content, Map<String, String> rsaKeyPair) {
    RSA ftfRsa = SecureUtil.rsa(rsaKeyPair.get("rsaFtfPriKey"),
        requireKey(rsaKeyPair, "rsaFtfPubKey"));
    byte[] outerCipher = Base64Utils.decode(utf8Encode(content));
    byte[] innerCipher = ftfRsa.decrypt(outerCipher, KeyType.PublicKey);

    RSA appRsa = SecureUtil.rsa(requireKey(rsaKeyPair, "rsaAppPriKey"),
        rsaKeyPair.get("rsaAppPubKey"));
    byte[] plain = appRsa.decrypt(innerCipher, KeyType.PrivateKey);
    return utf8Decode(plain);
  }

  /**
   * Encrypts {@code content} in two layers. It encrypts with the "app" public key first, then
   * with the "ftf" private key, and Base64-encodes the result.
   *
   * @param content plaintext, encoded as UTF-8 before encryption
   * @param rsaKeyPair key material; {@code rsaAppPubKey} and {@code rsaFtfPriKey} are required,
   *     and {@code rsaAppPriKey} / {@code rsaFtfPubKey} are passed through if present
   * @return the Base64-encoded ciphertext
   * @throws IllegalArgumentException if a required key is missing or blank
   */
  public static String encryptStr(CharSequence content, Map<String, String> rsaKeyPair) {
    RSA appRsa = SecureUtil.rsa(rsaKeyPair.get("rsaAppPriKey"),
        requireKey(rsaKeyPair, "rsaAppPubKey"));
    byte[] innerCipher = appRsa.encrypt(utf8Encode(content), KeyType.PublicKey);

    RSA ftfRsa = SecureUtil.rsa(requireKey(rsaKeyPair, "rsaFtfPriKey"),
        rsaKeyPair.get("rsaFtfPubKey"));
    byte[] outerCipher = ftfRsa.encrypt(innerCipher, KeyType.PrivateKey);
    return Base64Utils.encodeToString(outerCipher);
  }

  /**
   * Looks up a key that the current operation actually uses and rejects missing or blank
   * entries, instead of handing {@code null} to the RSA helper.
   *
   * <p>NOTE(review): hutool's RSA appears to generate a random key pair when both keys are
   * null, which would make encryption "succeed" with an unusable result. Confirm this against
   * the hutool version in use.
   */
  private static String requireKey(Map<String, String> rsaKeyPair, String name) {
    String key = rsaKeyPair.get(name);
    if (key == null || key.trim().isEmpty()) {
      throw new IllegalArgumentException("Missing RSA key '" + name + "'");
    }
    return key;
  }

  /**
   * Signs {@code content} with no extra header parameters.
   *
   * @param content the claims, encoded as UTF-8
   * @param signer the signer that produces the crypto segment
   * @return the signed token
   */
  public static Jwt encode(CharSequence content, Signer signer) {
    return encode(content, signer, Collections.emptyMap());
  }

  /**
   * Signs {@code content}, adding {@code headers} to the JWT header.
   *
   * @param content the claims, encoded as UTF-8
   * @param signer the signer; its algorithm becomes the header "alg"
   * @param headers extra header parameters (not modified)
   * @return the signed token
   */
  public static Jwt encode(CharSequence content, Signer signer,
      Map<String, String> headers) {
    JwtHeader header = JwtHeaderHelper.create(signer, headers);
    byte[] claims = utf8Encode(content);

    // The signing input is base64url(header) "." base64url(claims).
    byte[] crypto = signer
        .sign(concat(b64UrlEncode(header.bytes()), PERIOD, b64UrlEncode(claims)));
    return new JwtImpl(header, claims, crypto);
  }
}

/**
 * Helper object for {@link JwtHeader}.
 *
 * <p>Handles the JSON parsing and serialization of the header segment. The parser is
 * intentionally minimal. It understands a flat JSON object whose values contain no commas, and
 * it does not process escape sequences, so it is not a general-purpose JSON parser.
 */
class JwtHeaderHelper {

  /**
   * Builds a header from its Base64URL-encoded token segment.
   *
   * @param header the encoded header segment
   * @return the header, holding the decoded bytes and the parameters parsed from them
   * @throws IllegalArgumentException if the decoded header is not a JSON object, repeats a
   *     field, lacks "alg", or has a "typ" other than "JWT"
   */
  static JwtHeader create(String header) {
    byte[] bytes = b64UrlDecode(header);
    return new JwtHeader(bytes, parseParams(bytes));
  }

  /**
   * Builds a header for a token about to be signed by {@code signer}.
   *
   * @param signer its algorithm name is passed through {@code JwtAlgorithms.sigAlg} to obtain
   *     the "alg" value (overwriting any "alg" in {@code params})
   * @param params extra header parameters; copied, not modified
   * @return the header with its serialized JSON bytes
   */
  static JwtHeader create(Signer signer, Map<String, String> params) {
    Map<String, String> map = new LinkedHashMap<>(params);
    map.put("alg", JwtAlgorithms.sigAlg(signer.algorithm()));
    HeaderParameters p = new HeaderParameters(map);
    return new JwtHeader(serializeParams(p), p);
  }

  /** Decodes the header bytes as UTF-8 and parses them into {@link HeaderParameters}. */
  private static HeaderParameters parseParams(byte[] header) {
    Map<String, String> map = parseMap(utf8Decode(header));
    return new HeaderParameters(map);
  }

  /**
   * Parses a JSON object into a map. Blank input yields an empty map.
   *
   * @throws IllegalArgumentException if {@code json} is null or is not a JSON object
   */
  private static Map<String, String> parseMap(String json) {
    if (json == null) {
      throw new IllegalArgumentException("Invalid JSON (null)");
    }
    String trimmed = json.trim();
    if (trimmed.startsWith("{")) {
      return parseMapInternal(trimmed);
    }
    if (StringUtils.isBlank(trimmed)) {
      return new LinkedHashMap<>();
    }
    throw new IllegalArgumentException("Invalid JSON (expected an object)");
  }

  /**
   * Splits a flat JSON object into key/value pairs, stripping surrounding quotes.
   *
   * @throws IllegalArgumentException on a duplicate key
   */
  private static Map<String, String> parseMapInternal(String json) {
    Map<String, String> map = new LinkedHashMap<>();
    String body = trimLeadingCharacter(trimTrailingCharacter(json, '}'), '{');
    for (String pair : body.split(",")) {
      if (pair.trim().isEmpty()) {
        // An empty object ("{}") or stray commas leave empty fragments; nothing to record.
        continue;
      }
      // Limit to two parts so values that themselves contain ':' (URLs, URNs) stay intact.
      String[] values = pair.split(":", 2);
      String key = strip(values[0], '"');
      String value = values.length > 1 ? strip(values[1], '"') : null;
      if (map.containsKey(key)) {
        throw new IllegalArgumentException("Duplicate '" + key + "' field");
      }
      map.put(key, value);
    }
    return map;
  }

  /** Trims whitespace, then removes at most one {@code c} from each end. */
  private static String strip(String string, char c) {
    return trimLeadingCharacter(trimTrailingCharacter(string.trim(), c), c);
  }

  /** Removes one trailing {@code c}, if present; empty strings are returned unchanged. */
  private static String trimTrailingCharacter(String string, char c) {
    if (!string.isEmpty() && string.charAt(string.length() - 1) == c) {
      return string.substring(0, string.length() - 1);
    }
    return string;
  }

  /** Removes one leading {@code c}, if present; empty strings are returned unchanged. */
  private static String trimLeadingCharacter(String string, char c) {
    if (!string.isEmpty() && string.charAt(0) == c) {
      return string.substring(1);
    }
    return string;
  }

  /**
   * Serializes the parameters as a JSON object: "alg", then "typ", then the remaining
   * parameters in map order, encoded as UTF-8.
   *
   * <p>NOTE(review): values are written verbatim (no escaping of '"' or '\').
   */
  private static byte[] serializeParams(HeaderParameters params) {
    StringBuilder builder = new StringBuilder("{");

    appendField(builder, "alg", params.alg);
    if (params.typ != null) {
      appendField(builder, "typ", params.typ);
    }
    for (Entry<String, String> entry : params.map.entrySet()) {
      appendField(builder, entry.getKey(), entry.getValue());
    }
    builder.append("}");
    return utf8Encode(builder.toString());
  }

  /** Appends {@code "name":"value"}, preceded by a comma unless it is the first field. */
  private static void appendField(StringBuilder builder, String name, String value) {
    if (builder.length() > 1) {
      builder.append(",");
    }
    builder.append("\"").append(name).append("\":\"").append(value).append("\"");
  }
}

/**
 * Header segment of a JWT: the decoded header bytes together with the parameters they
 * describe.
 */
class JwtHeader implements BinaryFormat {

  /** Decoded (not Base64URL-encoded) header JSON bytes. */
  private final byte[] bytes;

  /** Parameters parsed from, or serialized into, {@link #bytes}. */
  final HeaderParameters parameters;

  /**
   * @param bytes the decoded header
   * @param parameters the parameter values contained in the header
   */
  JwtHeader(byte[] bytes, HeaderParameters parameters) {
    this.parameters = parameters;
    this.bytes = bytes;
  }

  /** Returns the decoded header bytes (the internal array, not a copy). */
  @Override
  public byte[] bytes() {
    return this.bytes;
  }

  /** Renders the header bytes as a UTF-8 string. */
  @Override
  public String toString() {
    final byte[] raw = this.bytes;
    return utf8Decode(raw);
  }
}

/**
 * JWT header parameters: the mandatory {@code alg}, the fixed {@code typ}, and any remaining
 * parameters in insertion order.
 */
class HeaderParameters {

  /** The header "alg" value; never null. */
  final String alg;

  /** Header parameters other than "alg" and "typ", in the order they were supplied. */
  final Map<String, String> map;

  /** Always "JWT"; an incoming "typ" is only validated, never stored. */
  final String typ = "JWT";

  HeaderParameters(String alg) {
    this(Collections.singletonMap("alg", alg));
  }

  /**
   * @param map raw header parameters, which must contain "alg". The map is copied, never
   *     modified, so immutable maps are accepted.
   * @throws IllegalArgumentException if "typ" is present and is not "JWT" (case-insensitive),
   *     or if "alg" is missing
   */
  HeaderParameters(Map<String, String> map) {
    // Work on a private copy so the caller's map is never mutated.
    Map<String, String> rest = new LinkedHashMap<>(map);
    String alg = rest.remove("alg");
    String typ = rest.remove("typ");
    if (typ != null && !"JWT".equalsIgnoreCase(typ)) {
      throw new IllegalArgumentException("typ is not \"JWT\"");
    }
    if (alg == null) {
      throw new IllegalArgumentException("alg is required");
    }
    this.alg = alg;
    this.map = rest;
  }

}

/**
 * {@link Jwt} implementation holding the header plus the Base64URL-decoded claims and crypto
 * segments.
 */
class JwtImpl implements Jwt {

  final JwtHeader header;

  private final byte[] content;

  private final byte[] crypto;

  /** UTF-8 decoding of {@link #content}, computed once in the constructor. */
  private final String claims;

  /**
   * @param header the header, containing the JWS/JWE algorithm information.
   * @param content the base64-decoded "claims" segment (may be encrypted, depending on header
   * information).
   * @param crypto the base64-decoded "crypto" segment.
   */
  JwtImpl(JwtHeader header, byte[] content, byte[] crypto) {
    this.header = header;
    this.content = content;
    this.crypto = crypto;
    this.claims = utf8Decode(content);
  }

  /**
   * Validates a signature contained in the 'crypto' segment by handing the signing input and
   * the crypto bytes to {@code verifier}; any exception it throws propagates.
   *
   * @param verifier the signature verifier
   */
  @Override
  public void verifySignature(SignatureVerifier verifier) {
    verifier.verify(signingInput(), crypto);
  }

  /** Rebuilds the signing input: base64url(header) "." base64url(claims). */
  private byte[] signingInput() {
    return concat(b64UrlEncode(header.bytes()), JwtHelper.PERIOD,
        b64UrlEncode(content));
  }

  /**
   * Allows retrieval of the full token.
   *
   * @return the encoded header, claims and crypto segments concatenated with "." characters
   */
  @Override
  public byte[] bytes() {
    return concat(b64UrlEncode(header.bytes()), JwtHelper.PERIOD,
        b64UrlEncode(content), JwtHelper.PERIOD, b64UrlEncode(crypto));
  }

  /** Returns the claims segment as a UTF-8 string (decoded once at construction). */
  @Override
  public String getClaims() {
    return claims;
  }

  /** Returns the full encoded token as a string. */
  @Override
  public String getEncoded() {
    return utf8Decode(bytes());
  }

  public JwtHeader header() {
    return this.header;
  }

  @Override
  public String toString() {
    return header + " " + claims + " [" + crypto.length + " crypto bytes]";
  }
}
